from collections import Counter
import pybedtools
from matplotlib import patches
from matplotlib.colors import LinearSegmentedColormap
from pylab import *
from numpy import *
from scipy.stats import pearsonr, spearmanr

from rpy2 import robjects
import rpy2.robjects.numpy2ri
robjects.numpy2ri.activate()
from rpy2.robjects.packages import importr
# Load the R packages through rpy2. DESeq2 is used by
# estimate_normalization_factors(); its version is printed for provenance.
deseq = importr('DESeq2')
print("Using DESeq2 version %s" % deseq.__version__)
# NOTE(review): glmGamPoi is not referenced in the code shown here;
# presumably it is loaded only to make sure it is installed - confirm
# before removing.
glmGamPoi = importr('glmGamPoi')


import warnings
# Turn every Python warning into an exception, presumably so that numerical
# or plotting problems abort the script instead of passing silently.
warnings.filterwarnings('error')






# Peak categories selected for this run: gene-associated (sense) peaks.
# read_ppvalues() silently drops peaks whose category is in categories_skip,
# keeps those in categories_keep, and raises for any category in neither set.
categories_skip = {
    "antisense",
    "prompt",
    "antisense_distal",
    "antisense_distal_upstream",
    "roadmap_dyadic",
    "roadmap_enhancer",
    "FANTOM5_enhancer",
    "novel_enhancer_CAGE",
    "novel_enhancer_HiSeq",
    "novel_enhancer_StartSeq",
    "other",
}
categories_keep = {
    "sense_proximal",
    "sense_upstream",
    "sense_distal",
    "sense_distal_upstream",
}

# enhancers
# categories_skip = set(["sense_proximal", "sense_upstream", "sense_distal", "sense_distal_upstream", "antisense", "prompt", "antisense_distal", "antisense_distal_upstream", "roadmap_dyadic", "other"])
# categories_keep = set(["roadmap_enhancer", "FANTOM5_enhancer", "novel_enhancer_CAGE", "novel_enhancer_HiSeq"])

# other
# categories_skip = set(["sense_proximal", "sense_upstream", "sense_distal", "sense_distal_upstream", "antisense", "prompt", "antisense_distal", "antisense_distal_upstream", "roadmap_dyadic", "roadmap_enhancer", "FANTOM5_enhancer", "novel_enhancer_CAGE", "novel_enhancer_HiSeq"])
# categories_keep = set(["other"])

# antisense
# categories_skip = set(["sense_proximal", "sense_upstream", "sense_distal", "sense_distal_upstream", "antisense_distal", "antisense_distal_upstream", "roadmap_dyadic", "roadmap_enhancer", "FANTOM5_enhancer", "novel_enhancer_CAGE", "novel_enhancer_HiSeq", "other"])
# categories_keep = set(["antisense", "prompt"])

# CASPAR
# categories_skip = set(["sense_proximal", "sense_upstream", "sense_distal", "sense_distal_upstream", "prompt", "antisense_distal", "antisense_distal_upstream", "roadmap_dyadic", "roadmap_enhancer", "FANTOM5_enhancer", "novel_enhancer_CAGE", "novel_enhancer_HiSeq", "other"])
# categories_keep = set(["antisense"])


def read_ppvalues():
    """Read the signed peak p-values from ``peaks.HiSeq_StartSeq.gff``.

    The GFF feature type (third column) is the peak category. Peaks whose
    category is in ``categories_skip`` are ignored; any category that is in
    neither ``categories_skip`` nor ``categories_keep`` raises an Exception.

    Returns a dict, in file order, mapping ``"<chrom>_<start>-<end>_<strand>"``
    (coordinates as reported by pybedtools) to the negated ``ppvalue``
    attribute. The rest of this script treats positive values as enriched
    in the HiSeq libraries and negative values as enriched in StartSeq.
    """
    filename = "peaks.HiSeq_StartSeq.gff"
    print("Reading", filename)
    ppvalues = {}
    for interval in pybedtools.BedTool(filename):
        category = interval.fields[2]
        if category in categories_skip:
            continue
        if category not in categories_keep:
            raise Exception("Unknown category %s" % category)
        peak_name = "%s_%d-%d_%s" % (interval.chrom, interval.start,
                                     interval.end, interval.strand)
        # Flip the sign of the stored attribute (see docstring).
        ppvalues[peak_name] = -float(interval.attrs['ppvalue'])
    print("Read %d transcription initiation peaks" % len(ppvalues))
    return ppvalues

def read_expression_data(ppvalues, filename="peaks.HiSeq_StartSeq.expression.txt"):
    """Load the expression rows of the peaks listed in *ppvalues*.

    The file is whitespace-separated. Its header must be exactly ``peak``
    followed by the 17 ``HiSeq_*`` and 2 ``StartSeq_*`` library columns
    listed in ``expected_header`` below, and every data line must have the
    same number of fields. Data lines are scanned in order. A line is kept
    when its peak name equals the next not-yet-found key of *ppvalues*. The
    file must therefore list the wanted peaks in the same relative order as
    the dict; other peaks may be interleaved and are skipped.

    Parameters
    ----------
    ppvalues : dict
        Peak name -> signed p-value. Only the keys and their order are used.
    filename : str
        Path of the expression table. Defaults to
        ``peaks.HiSeq_StartSeq.expression.txt`` in the working directory.

    Returns
    -------
    hiseq_indices, startseq_indices : int arrays
        Column indices (into ``data``) of the HiSeq and StartSeq libraries.
    data : float array of shape (len(ppvalues), number of libraries)
        One row per peak, in *ppvalues* order. An empty *ppvalues* yields
        an array with zero rows.

    Raises
    ------
    AssertionError
        If the header or a data line does not have the expected layout.
    ValueError
        If the end of the file is reached before every peak was found.
    """
    expected_header = [
        'peak',
        'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
        'HiSeq_t01_r1', 'HiSeq_t01_r2',
        'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
        'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
        'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
        'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
        'StartSeq_SRR7071452', 'StartSeq_SRR7071453',
    ]
    libraries = expected_header[1:]
    nlibraries = len(libraries)
    # Indices are relative to the library columns, i.e. to each data row
    # with its leading peak name removed.
    hiseq_indices = array([index for index, library in enumerate(libraries)
                           if library.startswith("HiSeq_")], int)
    startseq_indices = array([index for index, library in enumerate(libraries)
                              if library.startswith("StartSeq_")], int)

    peaks = list(ppvalues.keys())
    print("Reading", filename)
    data = []
    # The with-block closes the file even when a check below fails.
    with open(filename) as handle:
        header = next(handle).split()
        assert header == expected_header, \
            "Unexpected header in %s: %s" % (filename, header)
        if peaks:  # with no peaks requested there is nothing to scan for
            i = 0
            for line in handle:
                words = line.split()
                assert len(words) == nlibraries + 1
                if words[0] == peaks[i]:
                    data.append(array(words[1:], float))
                    i += 1
                    if i == len(peaks):
                        break
            else:
                # Ran out of lines before every requested peak was matched.
                raise ValueError("Failed to find all peaks")
    # reshape keeps the result 2-D even when no rows were collected.
    data = array(data, float).reshape(-1, nlibraries)
    return hiseq_indices, startseq_indices, data

def estimate_normalization_factors(counts, cage_indices, hiseq_indices):
    """Estimate DESeq2 size factors for each column (library) of *counts*.

    A sample table with a single ``dataset`` column is built. Columns listed
    in *cage_indices* are labelled "CAGE" and columns in *hiseq_indices* are
    labelled "HiSeq"; "HiSeq" is assigned last and wins for a column in both
    lists. Columns in neither list keep a ``None`` label. The labels feed the
    ``~ dataset`` design of the DESeqDataSet, after which
    ``estimateSizeFactors`` is run and ``sizeFactors`` is returned.

    NOTE(review): this script calls the function with the HiSeq column
    indices as *cage_indices* and the StartSeq column indices as
    *hiseq_indices*, so the label names do not match the libraries.
    Presumably harmless because the labels only populate the design -
    confirm against the DESeq2 size-factor documentation.

    Assumes *counts* is a 2-D peaks x libraries matrix of non-negative,
    integer-valued counts, as DESeq2 requires. The return value is whatever
    rpy2 converts R's ``sizeFactors`` result to; with numpy2ri activated it
    is presumably a numpy array (verify).
    """
    _, ncolumns = shape(counts)
    labels = [None] * ncolumns
    for label, indices in (("CAGE", cage_indices), ("HiSeq", hiseq_indices)):
        for index in indices:
            labels[index] = label
    sample_table = robjects.DataFrame({'dataset': robjects.StrVector(labels)})
    dds = deseq.DESeqDataSetFromMatrix(countData=counts,
                                       colData=sample_table,
                                       design=robjects.Formula("~ dataset"))
    dds = robjects.r['estimateSizeFactors'](dds)
    return robjects.r['sizeFactors'](dds)

def make_figure_scatter(hiseq_indices, startseq_indices, counts, factors, ppvalues):
    """Scatter-plot average HiSeq versus average StartSeq expression per peak.

    *counts* (peaks x libraries) is divided column-wise by *factors*. It is
    then rescaled so that the mean library total equals one million ("tpm").
    For each peak the tpm values are averaged over the HiSeq columns and over
    the StartSeq columns, and the two averages are plotted on log-log axes.
    Points are coloured by the peak's signed p-value on a green (-6) /
    lightgray / red (+6) colour scale. The Spearman correlation of the
    averages and the Pearson correlation of their log(x + 1) values are
    printed. The figure is saved as figure_sense_startseq_scatter.svg and
    figure_sense_startseq_scatter.png.

    The rows of *counts* must be in the same order as the keys of
    *ppvalues*, because the colours are taken from ``ppvalues.values()``
    positionally.
    """
    colormap = LinearSegmentedColormap.from_list("mycmap", ['green', 'lightgray', 'red'])
    colors = array(list(ppvalues.values()))

    normalized = counts / factors
    # Rescale so that the average library total is one million.
    normalized *= 1.e6 / mean(sum(normalized, 0))
    hiseq_tpm = mean(normalized[:, hiseq_indices], 1)
    startseq_tpm = mean(normalized[:, startseq_indices], 1)

    npoints = len(counts)
    assert len(hiseq_tpm) == npoints
    assert len(startseq_tpm) == npoints
    print("Plotting a scatter plot with %d points" % npoints)

    figure(figsize=(5,5))
    axes([0.20, 0.20, 0.55, 0.55])
    scatter(startseq_tpm, hiseq_tpm, c=colors, s=3, vmin=-6, vmax=+6, cmap=colormap)
    xscale('log')
    yscale('log')
    xticks(fontsize=8)
    yticks(fontsize=8)
    ylabel("Short capped RNA (single-end libraries),\naverage expression [tpm]", fontsize=8)
    xlabel("Start-Seq libraries, average expression [tpm]", fontsize=8)

    colorbar_axes = axes([0.80, 0.20, 0.03, 0.55])
    bar = colorbar(cax=colorbar_axes)
    bar.set_label("$-\\log(p) \\times \\mathrm{sign}$", fontsize=8)
    bar.ax.tick_params(labelsize=8)

    rho, pvalue = spearmanr(hiseq_tpm, startseq_tpm)
    print("HiSeq average expression vs StartSeq average expression: Spearman correlation across genes is %.2f (p = %g)" % (rho, pvalue))
    rho, pvalue = pearsonr(log(hiseq_tpm+1), log(startseq_tpm+1))
    print("HiSeq average expression vs StartSeq average expression: Pearson correlation across genes is %.2f (p = %g)" % (rho, pvalue))

    for filename in ("figure_sense_startseq_scatter.svg",
                     "figure_sense_startseq_scatter.png"):
        print("Saving the scatterplot figure to %s" % filename)
        savefig(filename)

ppvalues = read_ppvalues()

hiseq_indices, startseq_indices, data = read_expression_data(ppvalues)

# A signed p-value above +threshold counts as HiSeq-significant and one below
# -threshold as StartSeq-significant. Presumably the values are signed
# -log10(p), which makes this the p < 0.05 cut-off.
ppvalue_threshold = -log10(0.05)


def _peak_category(row, ppvalue):
    """Return the 4-tuple of labels under which one peak is tallied.

    The labels are, in order: whether the HiSeq columns of *row* sum to more
    than 1e-20, the same for the StartSeq columns, whether *ppvalue* exceeds
    +ppvalue_threshold, and whether it falls below -ppvalue_threshold.
    """
    hiseq_state = "hiseq_expressed" if sum(row[hiseq_indices]) > 1e-20 else "hiseq_not_expressed"
    startseq_state = "startseq_expressed" if sum(row[startseq_indices]) > 1e-20 else "startseq_not_expressed"
    hiseq_enriched = "hiseq_significant" if ppvalue > +ppvalue_threshold else "hiseq_not_significant"
    startseq_enriched = "startseq_significant" if ppvalue < -ppvalue_threshold else "startseq_not_significant"
    return (hiseq_state, startseq_state, hiseq_enriched, startseq_enriched)


# Rows of `data` follow the key order of `ppvalues`, so they pair up by zip.
counts = Counter(_peak_category(row, ppvalues[name])
                 for name, row in zip(ppvalues, data))

for key, count in counts.items():
    print(key, count)


# Venn-style summary. Two large circles mark peaks expressed in HiSeq (red)
# and StartSeq (green). A smaller circle inside each marks peaks that are
# significantly enriched in that library type.
circle_specs = [
    ((-1., 0), 3, 'red'),      # expressed as short capped RNA (HiSeq)
    ((+1., 0), 3, 'green'),    # expressed in Start-Seq
    ((-2, 0), 1.75, 'red'),    # significantly enriched in HiSeq
    ((+2, 0), 1.75, 'green'),  # significantly enriched in Start-Seq
]
circles = [Circle(center, radius, color=color, alpha=0.2)
           for center, radius, color in circle_specs]

figure()
ax = subplot(111)
for circle in circles:
    ax.add_patch(circle)

axis('square')
xlim(-5,5)
ylim(-5,5)

# (category key, x, y) for each region count, in drawing order; the counts
# end up in number1 ... number7 below.
centered = {'horizontalalignment': 'center', 'verticalalignment': 'center'}
region_labels = [
    (('hiseq_expressed', 'startseq_expressed', 'hiseq_not_significant', 'startseq_not_significant'), 0, -2),
    (('hiseq_not_expressed', 'startseq_expressed', 'hiseq_not_significant', 'startseq_not_significant'), +1.95, -2.25),
    (('hiseq_expressed', 'startseq_not_expressed', 'hiseq_not_significant', 'startseq_not_significant'), -1.95, -2.25),
    (('hiseq_expressed', 'startseq_expressed', 'hiseq_significant', 'startseq_not_significant'), -1.2, 0),
    (('hiseq_expressed', 'startseq_expressed', 'hiseq_not_significant', 'startseq_significant'), +1.2, 0),
    (('hiseq_expressed', 'startseq_not_expressed', 'hiseq_significant', 'startseq_not_significant'), -2.9, 0),
    (('hiseq_not_expressed', 'startseq_expressed', 'hiseq_not_significant', 'startseq_significant'), +2.9, 0),
]
region_counts = []
for region_key, x, y in region_labels:
    region_count = counts[region_key]  # Counter yields 0 for unseen keys
    text(x, y, str(region_count), **centered)
    region_counts.append(region_count)
number1, number2, number3, number4, number5, number6, number7 = region_counts

text(-2, 3, 'HiSeq', color='red', horizontalalignment='center', verticalalignment='bottom')
text(+2, 3, 'StartSeq', color='green', horizontalalignment='center', verticalalignment='bottom')

legend_handles = [
    patches.Patch(color='red', alpha=0.2, label='Expressed as\nshort capped RNAs'),
    patches.Patch(color='red', alpha=0.4, label='Significantly enriched\nin short capped RNA\n(single-end) libraries'),
    patches.Patch(color='green', alpha=0.2, label='Expressed in\nStart-Seq libraries'),
    patches.Patch(color='green', alpha=0.4, label='Significantly enriched\nin Start-Seq libraries'),
]
legend(handles=legend_handles, fontsize=8, loc='lower center', ncol=2, bbox_to_anchor=(0.5,-0.05))

title("Gene-associated peaks")

axis('off')

for filename in ("figure_peak_startseq_venn_diagram.svg",
                 "figure_peak_startseq_venn_diagram.png"):
    print("Saving figure to", filename)
    savefig(filename)

print("Total number of gene-associated peaks: %d" % len(ppvalues))
print("Expressed in HiSeq and StartSeq: %d" % (number1 + number4 + number5))
print("Expressed in HiSeq only: %d" % (number3 + number6))
print("Expressed in StartSeq only: %d" % (number2 + number7))
print("Significantly higher expression in HiSeq: %d" % (number4 + number6))
print("Significantly higher expression in StartSeq: %d" % (number5 + number7))


# Size factors for every library. The HiSeq columns are passed as the
# parameter named `cage_indices` and the StartSeq columns as the one named
# `hiseq_indices`. Those names only decide the labels written to the
# `dataset` column of DESeq2's sample table.
factors = estimate_normalization_factors(
    counts=data,
    cage_indices=hiseq_indices,
    hiseq_indices=startseq_indices,
)

make_figure_scatter(
    hiseq_indices=hiseq_indices,
    startseq_indices=startseq_indices,
    counts=data,
    factors=factors,
    ppvalues=ppvalues,
)
